﻿using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using HAPI;
using HuginApiAddonsCS.I_DID;
using HuginApiAddonsCS.Policy;
using HuginApiAddonsCS.Policy.Learning;

namespace HuginApiAddonsTestCS.Programs
{
    /// <summary>
    /// Program that learns policy trees from a CSV file using <see cref="PolicyTreeLearner"/> and then,
    /// depending on how many arguments were supplied, either times two tree-filling strategies or
    /// enters the learned trees into an I-DID network.
    /// </summary>
    class LearnTrees : IProgram
    {
        /// <summary>Number of repetitions averaged by the fill-timing mode.</summary>
        private const int TimingIterations = 100;

        #region IProgram Members

        /// <summary>
        /// Runs the program with positional arguments. <c>args[0]</c> is not read here; the learning
        /// inputs are <c>args[1]</c> (input CSV), <c>args[2]</c> (observation column), <c>args[3]</c>
        /// (comma-separated observations), <c>args[4]</c> (action column), <c>args[5]</c> (model column)
        /// and <c>args[6]</c> (time-step column).
        /// </summary>
        /// <remarks>
        /// Behaviour by argument count:
        /// <list type="bullet">
        /// <item><description>fewer than 7: prints "Invalid Parameters" and waits for a key press;</description></item>
        /// <item><description>7 to 10: only learns the trees;</description></item>
        /// <item><description>exactly 11: fill-timing mode (see <see cref="RunFillTiming"/>);</description></item>
        /// <item><description>12 or more: enters the trees into an I-DID, optionally followed by a merge
        /// (see <see cref="EnterTreesIntoIdid"/>).</description></item>
        /// </list>
        /// </remarks>
        /// <param name="args">Command-line arguments.</param>
        public void ExecuteProgram(string[] args)
        {
            if (args.Length < 7)
            {
                Console.WriteLine("Invalid Parameters");
                Console.WriteLine("Press any key to continue ...");
                Console.ReadKey();
                return;
            }

            string inputCsv = args[1];
            string obsCol = args[2];
            List<string> observations = args[3].Split(',').ToList();
            string actionCol = args[4];
            string modelCol = args[5];
            string timeStepCol = args[6];

            PolicyTreeLearner learner = new PolicyTreeLearner();

            // Both follow-up modes learn the trees again from scratch before each use.
            Func<List<PolicyNode>> relearn =
                () => learner.LearnFromKnownHorizon(inputCsv, obsCol, actionCol, modelCol, timeStepCol, observations);

            List<PolicyNode> trees = relearn();

            if (args.Length == 11)
            {
                RunFillTiming(args, inputCsv, learner, relearn);
            }
            else if (args.Length >= 12)
            {
                EnterTreesIntoIdid(args, learner, relearn, trees);
            }
        }

        /// <summary>
        /// Returns the usage lines describing the accepted argument layouts.
        /// </summary>
        /// <remarks>
        /// The 11-argument fill-timing mode handled by <see cref="ExecuteProgram"/> is not listed here.
        /// </remarks>
        /// <returns>One usage string per supported layout.</returns>
        public string[] GetCommandOptions()
        {
            string[] options =
            {
                "[input file] [obs col] [observations] [action col] [model col] [time step col]",
                "[input file] [obs col] [observations] [action col] [model col] [time step col] [input net] [output net] [aj] [oj] [mod]",
                "[input file] [obs col] [observations] [action col] [model col] [time step col] [input net] [output net] [aj] [oj] [mod] [output merged net] [mergeType] [merge value]",
                "[input file] [obs col] [observations] [action col] [model col] [time step col] [input net] [output net] [aj] [oj] [mod] [output merged net] [mergeType] [merge value] [merges file]"
            };

            return options;
        }

        #endregion

        /// <summary>
        /// Fill-timing mode (exactly 11 arguments). Over <see cref="TimingIterations"/> runs, measures the
        /// average time to (a) learn, random-fill and write the trees, and (b) learn, fill missing nodes via
        /// <c>FillMissingNodesDifference</c>, random-fill and write the trees. Prints both averages.
        /// </summary>
        /// <param name="args">
        /// <c>args[7]</c>: random-fill output file; <c>args[8]</c>: compatible-fill output file (merges are
        /// written to the same path with ".merges" appended); <c>args[9]</c>: comma-separated actions;
        /// <c>args[10]</c>: epsilon passed to <c>FillMissingNodesDifference</c>.
        /// </param>
        /// <param name="inputCsv">Input CSV path, used only in the printed header.</param>
        /// <param name="learner">Learner used for the difference-based fill.</param>
        /// <param name="relearn">Learns a fresh set of trees from the input CSV.</param>
        /// <remarks>Each output file is overwritten on every iteration, so only the last run's trees remain.</remarks>
        private static void RunFillTiming(string[] args, string inputCsv, PolicyTreeLearner learner, Func<List<PolicyNode>> relearn)
        {
            string randomFillOutputFile = args[7];
            string compatibleFillOutputFile = args[8];
            double epsilon = Convert.ToDouble(args[10]);
            List<string> actions = args[9].Split(',').ToList();

            // One shared generator: on .NET Framework, Random instances created in quick succession are
            // seeded from the same tick count and would give every tree the same "random" fill.
            Random random = new Random();
            Stopwatch stopwatch = new Stopwatch();

            stopwatch.Start();
            for (int i = 0; i < TimingIterations; i++)
            {
                List<PolicyNode> trees = relearn();
                using (StreamWriter writer = new StreamWriter(randomFillOutputFile))
                {
                    int treeIndex = 0;
                    foreach (PolicyNode rootNode in trees)
                    {
                        rootNode.RandomFill(random, rootNode.GetTreeLength(), actions);
                        writer.WriteLine(rootNode.ToTreeString(string.Empty, treeIndex));
                        treeIndex++;
                    }
                }
            }
            stopwatch.Stop();

            // Stopwatch.ElapsedTicks is in Stopwatch.Frequency units, not TimeSpan ticks; use Elapsed instead.
            double randomTime = stopwatch.Elapsed.TotalMilliseconds / TimingIterations;

            stopwatch.Reset();
            stopwatch.Start();
            for (int i = 0; i < TimingIterations; i++)
            {
                List<PolicyNode> trees = relearn();

                // The running stopwatch is handed to the learner; presumably it pauses it around its own
                // file I/O (NOTE(review): confirm in PolicyTreeLearner.FillMissingNodesDifference).
                trees = learner.FillMissingNodesDifference(trees, epsilon, compatibleFillOutputFile + ".merges", stopwatch);

                using (StreamWriter writer = new StreamWriter(compatibleFillOutputFile))
                {
                    // All trees are filled before any is written. Kept as two passes; merged trees may
                    // share nodes (unverified), in which case interleaving would change the output.
                    foreach (PolicyNode rootNode in trees)
                    {
                        rootNode.RandomFill(random, rootNode.GetTreeLength(), actions);
                    }

                    int treeIndex = 0;
                    foreach (PolicyNode rootNode in trees)
                    {
                        writer.WriteLine(rootNode.ToTreeString(string.Empty, treeIndex));
                        treeIndex++;
                    }
                }
            }
            stopwatch.Stop();

            double compatibleTime = stopwatch.Elapsed.TotalMilliseconds / TimingIterations;

            Console.WriteLine("Timing for " + inputCsv);
            Console.WriteLine("Random Fill Learning Time = " + randomTime + "(ms)");
            Console.WriteLine("Compatible Fill Learning Time = " + compatibleTime + "(ms)");
            Console.WriteLine();
            Console.WriteLine();
        }

        /// <summary>
        /// I-DID mode (12 or more arguments). Enters the already-learned trees into the output domain. When
        /// at least 15 arguments are given, it also learns the trees again and fills missing nodes using the
        /// requested <see cref="MergeType"/>, then enters the result into a second output network.
        /// </summary>
        /// <param name="args">
        /// <c>args[7]</c> input domain, <c>args[8]</c> output domain, <c>args[9]</c> aj prefix,
        /// <c>args[10]</c> oj prefix, <c>args[11]</c> mod prefix; optionally <c>args[12]</c> merged output
        /// net, <c>args[13]</c> merge type name, <c>args[14]</c> epsilon, <c>args[15]</c> merges file.
        /// </param>
        /// <param name="learner">Learner used to fill missing nodes.</param>
        /// <param name="relearn">Learns a fresh set of trees from the input CSV.</param>
        /// <param name="trees">Trees learned before this mode was entered.</param>
        private static void EnterTreesIntoIdid(string[] args, PolicyTreeLearner learner, Func<List<PolicyNode>> relearn, List<PolicyNode> trees)
        {
            IdidEnterModels enterModels = new IdidEnterModels();
            string inputDomain = args[7];
            string outputDomain = args[8];
            string ajPrefix = args[9];
            string ojPrefix = args[10];
            string modPrefix = args[11];

            enterModels.EnterPolicyTreesIntoIdid(inputDomain, outputDomain, trees.ToList(), ajPrefix, modPrefix, ojPrefix);

            if (args.Length < 15)
            {
                return;
            }

            string outputMergedNet = args[12];
            double epsilon = Convert.ToDouble(args[14]);
            MergeType mType = (MergeType)Enum.Parse(typeof(MergeType), args[13]);
            string mergesFile = args.Length >= 16 ? args[15] : string.Empty;

            // Start the merge from freshly learned trees rather than the list entered above.
            trees = relearn();

            // MergeType values other than these two leave the trees unfilled.
            switch (mType)
            {
                case MergeType.HOEFF:
                    trees = learner.FillMissingNodesHoeffding(trees, epsilon, mergesFile);
                    break;
                case MergeType.SIM:
                    // No timing is done in this mode, so an unstarted stopwatch is supplied.
                    trees = learner.FillMissingNodesDifference(trees, epsilon, mergesFile, new Stopwatch());
                    break;
            }

            enterModels.EnterPolicyTreesIntoIdid(inputDomain, outputMergedNet, trees.ToList(), ajPrefix, modPrefix, ojPrefix);
        }
    }
}
